import numpy as np
from PIL import Image
import cv2

# OUTPUT_ZOOM = 1.0        # how much to zoom output relative to *original* image
# OUTPUT_DPI = 300         # just affects stated DPI of PNG, not appearance
# REMAP_DECIMATE = 16      # downscaling factor for remapping image
# FOCAL_LENGTH = 1.2       # normalized focal length of camera

# # default intrinsic parameter matrix
# K = np.array([
#     [FOCAL_LENGTH, 0, 0],
#     [0, FOCAL_LENGTH, 0],
#     [0, 0, 1]], dtype=np.float32)

# def round_nearest_multiple(i, factor):
#     i = int(i)
#     rem = i % factor
#     if not rem:
#         return i
#     else:
#         return i + factor - rem

# def project_xy(xy_coords):

#     # get cubic polynomial coefficients given
#     #
#     #  f(0) = 0, f'(0) = alpha
#     #  f(1) = 0, f'(1) = beta
#     # -0.3205797803925317   0.14943099006876762   \
#     # [ 0.12316831 -0.07648575  0.04499854]   \
#     # [-0.59512168 -0.97361331  1.13721823]
#     alpha = -0.3205797803925317
#     beta = 0.14943099006876762 
#     poly = np.array([
#         alpha + beta,
#         -2*alpha - beta,
#         alpha,
#         0])

#     RVEC = np.array([ 0.12316831, -0.07648575, 0.04499854])
#     TVEC = np.array([-0.59512168, -0.97361331, 1.13721823])
    
#     xy_coords = xy_coords.reshape((-1, 2))
#     z_coords = np.polyval(poly, xy_coords[:, 0])

#     objpoints = np.hstack((xy_coords, z_coords.reshape((-1, 1))))

#     image_points, _ = cv2.projectPoints(objpoints,
#                                         RVEC,
#                                         TVEC,
#                                         K, np.zeros(5))

#     return image_points


# def project_xy_(xy_coords):

#     # get cubic polynomial coefficients given
#     #
#     #  f(0) = 0, f'(0) = alpha
#     #  f(1) = 0, f'(1) = beta
#     # -0.3205797803925317   0.14943099006876762   \
#     # [ 0.12316831 -0.07648575  0.04499854]   \
#     # [-0.59512168 -0.97361331  1.13721823]
#     alpha = -0.3205797803925317
#     beta = 0.14943099006876762 
#     poly = np.array([
#         alpha + beta,
#         -2*alpha - beta,
#         alpha,
#         0])

#     RVEC = np.array([ 0.12316831, -0.07648575, 0.04499854])
#     TVEC = np.array([-0.59512168, -0.97361331, 1.13721823])
    
#     rvecM = cv2.Rodrigues(RVEC)[0]
#     print(rvecM)
#     rvecM_inv = np.linalg.pinv(rvecM)
#     K_inv = np.linalg.pinv(K)
#     xy_coords = xy_coords.reshape((-1, 2))
#     z_coords = np.polyval(poly, xy_coords[:, 0])

#     objpoints = np.hstack((xy_coords, z_coords.reshape((-1, 1))))
#     objpoints[:, 2] = 1
#     tmp = rvecM_inv.dot(K_inv.dot(objpoints.T)).T
#     tmp1 = rvecM_inv.dot(TVEC).T  #(z_const)
#     print(tmp.shape, " ", tmp1.shape)
#     s = (tmp1[2] + z_coords) / tmp[:, 2] 
#     print("s s s {} {}".format(s.shape, s) )
#     image_points = rvecM_inv.dot((s.reshape((-1, 1)) * (K_inv.dot(objpoints.T)).T - TVEC).T )
#     print(image_points, " ", image_points.dtype)    
#     image_points = (image_points.T)[..., :2]#/((image_points.T)[..., 2]).reshape((-1, 1))
#     return image_points.reshape((-1,1,2)).copy().astype(np.float32)

# def norm2pix(shape, pts, as_integer):
#     height, width = shape[:2]
#     scl = max(height, width)*0.5
#     offset = np.array([0.5*width, 0.5*height],
#                       dtype=pts.dtype).reshape((-1, 1, 2))
#     rval = pts * scl + offset
#     if as_integer:
#         return (rval + 0.5).astype(int)
#     else:
#         return rval

# def remap_image(name, img, page_dims=[1., 1.]):

#     height = 0.5 * page_dims[1] * OUTPUT_ZOOM * img.shape[0]
#     height = round_nearest_multiple(height, REMAP_DECIMATE)

#     width = round_nearest_multiple(img.shape[1] * page_dims[0] / page_dims[1],
#                                    REMAP_DECIMATE)

#     print('  output will be {}x{}'.format(width, height))

#     height_small = height / REMAP_DECIMATE
#     width_small = width / REMAP_DECIMATE

#     page_x_range = np.linspace(0, page_dims[0], width_small)
#     page_y_range = np.linspace(0, page_dims[1], height_small)

#     page_x_coords, page_y_coords = np.meshgrid(page_x_range, page_y_range)

#     page_xy_coords = np.hstack((page_x_coords.flatten().reshape((-1, 1)),
#                                 page_y_coords.flatten().reshape((-1, 1))))

#     page_xy_coords = page_xy_coords.astype(np.float32)

#     print(page_xy_coords)
#     image_points = project_xy_(page_xy_coords)
#     print(image_points)
#     print("*****************")
#     image_points = norm2pix(img.shape, image_points, False)

#     print(image_points)
#     image_x_coords = image_points[:, 0, 0].reshape(page_x_coords.shape)
#     image_y_coords = image_points[:, 0, 1].reshape(page_y_coords.shape)

#     image_x_coords = cv2.resize(image_x_coords, (width, height),
#                                 interpolation=cv2.INTER_CUBIC)

#     image_y_coords = cv2.resize(image_y_coords, (width, height),
#                                 interpolation=cv2.INTER_CUBIC)

#     # print(image_x_coords[:,300])
#     print(image_y_coords)
#     img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
#     print(img_gray)

#     remapped = cv2.remap(img_gray, image_x_coords-1300, image_y_coords-2000,
#                          cv2.INTER_CUBIC,
#                          None, cv2.BORDER_REPLICATE)

#     pil_image = Image.fromarray(remapped)
#     # pil_image = pil_image.convert('1')

#     path = './warps/' 
#     pil_image.save(path + name, dpi=(OUTPUT_DPI, OUTPUT_DPI))

#     print("remapped ", remapped.shape)
#     cv2.imshow("remapped", remapped)
#     cv2.imwrite('remapped.jpg', remapped)
#     cv2.imshow("ori.jpg", img_gray)
#     cv2.waitKey()
# alpha********************
# 0.06813573072127965   0.28235706974708863
# beta********************
# (2304, 1296, 3) fx  1706.25
# (2985984,)   (2985984,)
# (2304, 1296, 3)
# *********
# (2304, 1296, 3)
# <class 'numpy.ndarray'>
# [[ 8.77856475e-01 -1.61952065e-01  1.52004578e+02]
#  [ 7.18365011e-02  8.47882098e-01 -3.68611794e+01]
#  [-3.04905217e-05 -7.12470073e-05  1.00000000e+00]]

# def addCurvedWarp(image):
#     remap_image("1_10.jpg", image)

# Random 3D-rotation parameters read by get_warpR(). They are drawn once, at
# import time; nothing in this file re-draws them, so every image processed in
# one run gets the same tilt.
# NOTE(review): re-draw per image if per-image variety is intended.
angle = 25  # draws below come from [-angle, angle) degrees (randint excludes the upper bound)
angley = np.random.randint(-angle, angle)  # rotation about the y axis (degrees), used in ry
anglex = np.random.randint(-angle, angle)  # rotation about the x axis (degrees), used in rx
anglez = 0  # in-plane (z) rotation disabled; previously: np.random.randint(0, angle)
fov = np.random.randint(40, 60)  # full field of view (degrees); get_warpR uses fov / 2
r = 0  # replaced by get_warpR() with the combined 4x4 rotation matrix
w, h = 0, 0  # image width/height in pixels; set by addCurvedWarp_() before get_warpR()
def rad(x):
    """Convert an angle (scalar or numpy array) from degrees to radians."""
    half_turn = np.pi  # 180 degrees expressed in radians
    return half_turn * x / 180

def get_warpR():
    """Return a 3x3 perspective matrix that simulates a 3D tilt of the image.

    Reads module globals: ``anglex``/``angley``/``anglez`` (rotation about the
    x/y/z axes, degrees), ``fov`` (full field of view, degrees) and ``w``/``h``
    (image width/height in pixels, which the caller must set first).

    The image plane is rotated about its centre by ``rx @ ry @ rz``. Its four
    corners are then projected through a pinhole camera at distance ``z``, and
    ``cv2.getPerspectiveTransform`` fits the homography that maps the original
    corners to the projected ones.

    Side effect: the combined 4x4 homogeneous rotation is stored in the global
    ``r``.

    Returns:
        3x3 homography suitable for ``cv2.warpPerspective``.
    """
    global r  # only r is written; the other globals are only read

    # Camera-to-image distance. At this distance a field of view of `fov`
    # degrees spans exactly the image diagonal, so the whole image is visible.
    z = np.sqrt(w ** 2 + h ** 2) / 2 / np.tan(rad(fov / 2))

    ax, ay, az = rad(anglex), rad(angley), rad(anglez)
    # Homogeneous (4x4) rotation matrices.
    rx = np.array([[1, 0, 0, 0],
                   [0, np.cos(ax), -np.sin(ax), 0],
                   # +sin here: a proper rotation (orthogonal, det = 1)
                   [0, np.sin(ax), np.cos(ax), 0],
                   [0, 0, 0, 1]], np.float32)
    ry = np.array([[np.cos(ay), 0, np.sin(ay), 0],
                   [0, 1, 0, 0],
                   [-np.sin(ay), 0, np.cos(ay), 0],
                   [0, 0, 0, 1]], np.float32)
    rz = np.array([[np.cos(az), np.sin(az), 0, 0],
                   [-np.sin(az), np.cos(az), 0, 0],
                   [0, 0, 1, 0],
                   [0, 0, 0, 1]], np.float32)
    r = rx.dot(ry).dot(rz)

    # Image corners in pixel coordinates, (x = column, y = row).
    org = np.array([[0, 0],
                    [w, 0],
                    [0, h],
                    [w, h]], np.float32)
    # Rotation pivot: the image centre, ordered (x, y) to match `org`.
    pcenter = np.array([w / 2, h / 2, 0, 0], np.float32)

    # Corners as homogeneous points relative to the centre, rotated in 3D.
    corners = np.zeros((4, 4), np.float32)
    corners[:, :2] = org
    rotated = (corners - pcenter).dot(r.T)  # row i == r @ (corner_i - pcenter)

    # Perspective projection back onto the image plane, then undo the centring.
    depth_scale = z / (z - rotated[:, 2])
    dst = (rotated[:, :2] * depth_scale[:, None] + pcenter[:2]).astype(np.float32)

    return cv2.getPerspectiveTransform(org, dst)

def AddBorder(image, top=None, bottom=None, left=None, right=None, constant_v=True):
    """Pad ``image`` on all four sides.

    Any side left as ``None`` defaults to 10% of the image extent along that
    axis: rows for top/bottom, columns for left/right.

    Args:
        image: array accepted by ``cv2.copyMakeBorder``.
        top, bottom, left, right: border widths in pixels, or None for the
            10% default.
        constant_v: if True, pad with black (``cv2.BORDER_CONSTANT`` with
            value 0). Otherwise replicate the edge pixels
            (``cv2.BORDER_REPLICATE``).

    Returns:
        The padded image.
    """
    factor_y, factor_x = 0.1, 0.1  # default border factor
    default_y = int(factor_y * image.shape[0])  # shape[0] = rows
    default_x = int(factor_x * image.shape[1])  # shape[1] = cols

    # Honour explicitly requested widths; fall back to the 10% default.
    top = default_y if top is None else top
    bottom = default_y if bottom is None else bottom
    left = default_x if left is None else left
    right = default_x if right is None else right

    # Only build the variant that is actually returned.
    if constant_v:
        bordered = cv2.copyMakeBorder(image, top, bottom, left, right,
                                      cv2.BORDER_CONSTANT, None, [0, 0, 0])
    else:
        bordered = cv2.copyMakeBorder(image, top, bottom, left, right,
                                      cv2.BORDER_REPLICATE)

    print("ori_image:{}, addborder:{}".format(image.shape, bordered.shape))
    return bordered
    
def addCurvedWarp_(image):
    """Randomly curve ``image`` like a bent page, then tilt it in 3D.

    1. Curve: each output pixel is converted to normalised camera coordinates
       (X, Y), with the principal point at the image centre. (X, Y) is scaled
       by ``1 + p(X)``, where ``p`` is a random cubic with p(0) = p(1) = 0,
       p'(0) = alpha and p'(1) = beta. The source image is bilinearly sampled
       at the resulting position. Samples whose 2x2 neighbourhood falls
       outside the image stay 0.
    2. Tilt: sets the globals ``w``/``h`` and applies the homography from
       ``get_warpR()`` with ``cv2.warpPerspective``.

    Accepts 2-D (grayscale) as well as H x W x C images.

    Side effects: draws from ``np.random``, prints progress, and writes the
    curved image to ``dst.jpg`` in the current working directory.

    Returns:
        (dst, dst1): the curved image, and the curved-and-tilted image.
    """
    global w, h  # read by get_warpR()

    dst = np.zeros(image.shape, dtype=image.dtype)
    print("Before adding warp shape:image/{},dst/{}".format(image.shape, dst.shape))

    alpha = np.random.uniform(-0.5, 0.5)
    beta = np.random.uniform(-1, 0.5)
    print("------------alpha:{}, beta:{}--------------".format(alpha, beta))

    # Coefficients (highest power first, as np.polyval expects) of the cubic
    # with p(0) = p(1) = 0, p'(0) = alpha, p'(1) = beta.
    poly = np.array([alpha + beta, -2 * alpha - beta, alpha, 0])

    height, width = image.shape[:2]

    # Pinhole intrinsics: focal length 500 at a 640x480 reference, scaled to
    # this image size and averaged so both axes share one focal length.
    cx, cy = width / 2, height / 2
    fx = 500 * (width / 640.0)
    fy = 500 * (height / 480.0)
    fx = fy = (fx + fy) / 2.0
    print(image.shape, "fx ", fx)

    # Row/column index of every output pixel, flattened in row-major order.
    rows, cols = np.indices((height, width))
    Y = (rows.ravel() - cy) / fy  # camera coordinates
    X = (cols.ravel() - cx) / fx
    z_coords = np.polyval(poly, X)

    # Scale by (1 + p(X)) and map back to source pixel coordinates.
    u = X * (1 + z_coords) * fx + cx
    v = Y * (1 + z_coords) * fy + cy

    valid, values = _bilinear_sample(image, u, v)
    # dst is a fresh C-contiguous array, so this reshape is a view into it.
    flat_dst = dst.reshape((height * width,) + image.shape[2:])
    flat_dst[valid] = values  # cast back to image.dtype truncates toward zero

    w = width
    h = height
    warpR = get_warpR()
    dst1 = cv2.warpPerspective(dst, warpR, (width, height))

    print("After adding warp shape:image/{},dst/{}".format(image.shape, dst1.shape))

    # NOTE(review): debug snapshot, overwritten in the working directory on every call.
    cv2.imwrite("dst.jpg", dst)

    return dst, dst1


def _bilinear_sample(image, u, v):
    """Bilinearly sample ``image`` at float coords (u = column, v = row).

    ``u`` and ``v`` are equal-length 1-D float arrays. Returns ``(valid,
    values)``: ``valid`` is a boolean mask of points whose whole 2x2
    neighbourhood lies inside the image, and ``values`` holds the float64
    interpolated pixels for those points only.
    """
    height, width = image.shape[:2]
    # floor, not astype(int): truncation rounds toward zero, which would map
    # coordinates in (-1, 0) onto index 0 with a negative (extrapolating) weight.
    u0 = np.floor(u).astype(np.int32)
    v0 = np.floor(v).astype(np.int32)
    u1 = u0 + 1
    v1 = v0 + 1
    print("u0:{}, v0:{}".format(u0.shape, v0.shape))

    valid = (u0 >= 0) & (u1 < width) & (v0 >= 0) & (v1 < height)
    u0, v0, u1, v1 = u0[valid], v0[valid], u1[valid], v1[valid]

    # Trailing singleton axes let the weights broadcast over channels, if any.
    extra = (1,) * (image.ndim - 2)
    dx = (u[valid] - u0).reshape((-1,) + extra)
    dy = (v[valid] - v0).reshape((-1,) + extra)

    values = ((1 - dx) * (1 - dy) * image[v0, u0]
              + dx * (1 - dy) * image[v0, u1]
              + (1 - dx) * dy * image[v1, u0]
              + dx * dy * image[v1, u1])
    return valid, values

# NOTE(review): a bare module-level string with no runtime effect. It looks like
# a scratch list of image identifiers of the form <a>_<b> (e.g. 91_9); their
# meaning is not recorded in this file. Confirm, or move this into a comment.
'''
11_14      77_10
29_18      85_15
38_9       81_10
64_13      88_20
65_10      91_9
67_20
70_10
73_11
84_9
'''
# borrowed from https://github.com/lengstrom/fast-style-transfer/blob/master/src/utils.py
def get_files(img_dir):
    """Return ``(images, masks, ground_truths)`` path lists found under ``img_dir``.

    Delegates entirely to ``list_files``.
    """
    return list_files(img_dir)

def list_files(in_path):
    """Recursively collect files under ``in_path``, grouped by extension.

    Extensions are matched case-insensitively:
      * images: .jpg .jpeg .gif .png .pgm
      * masks: .bmp
      * ground truth: .xml .gt .txt
    Anything else (including .zip) is ignored.

    Returns:
        (img_files, mask_files, gt_files): lists of paths (``dirpath`` joined
        with the file name), in ``os.walk`` order (not sorted).
    """
    # The module imports os only under __main__; import it here so this
    # function also works when the module is imported.
    import os

    image_exts = (".jpg", ".jpeg", ".gif", ".png", ".pgm")
    mask_exts = (".bmp",)
    gt_exts = (".xml", ".gt", ".txt")

    img_files = []
    mask_files = []
    gt_files = []
    for dirpath, _dirnames, filenames in os.walk(in_path):
        for file in filenames:
            ext = os.path.splitext(file)[1].lower()
            path = os.path.join(dirpath, file)
            if ext in image_exts:
                img_files.append(path)
            elif ext in mask_exts:
                mask_files.append(path)
            elif ext in gt_exts:
                gt_files.append(path)
    return img_files, mask_files, gt_files

if __name__ == "__main__":
    import sys
    import os

    if len(sys.argv) < 2:
        sys.exit("usage: {} <image file or directory>".format(sys.argv[0]))

    # Accept either a single image path or a directory to scan recursively.
    if os.path.isdir(sys.argv[1]):
        image_list, _, _ = get_files(sys.argv[1])
    else:
        image_list = [sys.argv[1]]

    out_dir = "curve_warp"
    # cv2.imwrite returns False instead of raising when the directory is missing.
    os.makedirs(out_dir, exist_ok=True)

    for k, imgfile in enumerate(image_list):
        print("process {}th image {}".format(k, imgfile))
        image_path = imgfile
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None (rather than raising) on unreadable files.
            print("skip unreadable image {}".format(image_path))
            continue

        image = AddBorder(image)
        dst, dst1 = addCurvedWarp_(image)
        name = os.path.basename(image_path)
        cv2.imwrite(os.path.join(out_dir, "cubic_" + name), dst)
        cv2.imwrite(os.path.join(out_dir, "cubicRT_" + name), dst1)